-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ravel_pytree
now produces jit-compatible unravel functions
#13834
ravel_pytree
now produces jit-compatible unravel functions
#13834
Conversation
Ping @mattjj |
Thanks for this, Patrick. Sorry I've been so slow to respond for the last N months. (I wish I could say only "recently"!) This is a great idea. I adapted it in #14954 to reuse the Thanks for this improvement! |
Haha! Thanks, glad to see this in. FWIW the docs on |
Good catch! Also I done goofed in another way: I was thinking of |
Hah, I feel like the various kinds of JAX-internal function wrappers are starting to get a bit complicated. Off the top of my head there's Outside of core JAX I just use Side note, the issue fixed in this PR is a pretty common one in JAX -- e.g. |
You might be a lumper. Internal JAX utilities are set up to make lumpers out themselves. 😛 |
We effectively merged this as #14954! |
Previously,
would unecessarily induce recompilation.